package org.apache.lucene.search.suggest.analyzing;

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.MockAnalyzer;
import org.apache.lucene.analysis.MockTokenizer;
import org.apache.lucene.analysis.Tokenizer;
import org.apache.lucene.analysis.core.StopFilter;
import org.apache.lucene.analysis.util.CharArraySet;
import org.apache.lucene.document.Document;
import org.apache.lucene.search.suggest.Lookup.LookupResult;
import org.apache.lucene.search.suggest.Input;
import org.apache.lucene.search.suggest.InputArrayIterator;
import org.apache.lucene.search.suggest.InputIterator;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.LineFileDocs;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.TestUtil;
import org.junit.Ignore;

public class TestFreeTextSuggester extends LuceneTestCase {

  public void testBasic() throws Exception {
    Iterable<Input> keys = AnalyzingSuggesterTest.shuffle(
        new Input("foo bar baz blah", 50),
        new Input("boo foo bar foo bee", 20)
    );

    Analyzer a = new MockAnalyzer(random());
    FreeTextSuggester sug = new FreeTextSuggester(a, a, 2, (byte) 0x20);
    sug.build(new InputArrayIterator(keys));
    assertEquals(2, sug.getCount());

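    // The two inputs above contain 9 tokens in total (foo x3, bar x2,
    // baz/blah/boo/bee x1 each), and for the context "foo" the bigrams
    // "foo bar" (x2) and "foo bee" (x1). The expected scores below follow
    // from stupid backoff: count/contextCount, times ALPHA (0.4) for each
    // backoff step to a shorter gram.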
    for(int i=0;i<2;i++) {

      // Uses bigram model and unigram backoff:
      assertEquals("foo bar/0.67 foo bee/0.33 baz/0.04 blah/0.04 boo/0.04",
                   toString(sug.lookup("foo b", 10)));

      // Uses only bigram model:
      assertEquals("foo bar/0.67 foo bee/0.33",
                   toString(sug.lookup("foo ", 10)));

      // Uses only unigram model:
      assertEquals("foo/0.33",
                   toString(sug.lookup("foo", 10)));

      // Uses only unigram model:
      assertEquals("bar/0.22 baz/0.11 bee/0.11 blah/0.11 boo/0.11",
                   toString(sug.lookup("b", 10)));

      // Try again after save/load:
      Path tmpDir = createTempDir("FreeTextSuggesterTest");

      Path path = tmpDir.resolve("suggester");

      OutputStream os = Files.newOutputStream(path);
      sug.store(os);
      os.close();

      InputStream is = Files.newInputStream(path);
      sug = new FreeTextSuggester(a, a, 2, (byte) 0x20);
      sug.load(is);
      is.close();
      assertEquals(2, sug.getCount());
    }
    a.close();
  }

  public void testIllegalByteDuringBuild() throws Exception {
    // Default separator is INFORMATION SEPARATOR TWO
    // (0x1e), so no input token is allowed to contain it
    Iterable<Input> keys = AnalyzingSuggesterTest.shuffle(
        new Input("foo\u001ebar baz", 50)
    );
    Analyzer analyzer = new MockAnalyzer(random());
    FreeTextSuggester sug = new FreeTextSuggester(analyzer);
    try {
      sug.build(new InputArrayIterator(keys));
      fail("did not hit expected exception");
    } catch (IllegalArgumentException iae) {
      // expected
    }
    analyzer.close();
  }

  public void testIllegalByteDuringQuery() throws Exception {
    // Default separator is INFORMATION SEPARATOR TWO
    // (0x1e), so no input token is allowed to contain it
    Iterable<Input> keys = AnalyzingSuggesterTest.shuffle(
        new Input("foo bar baz", 50)
    );
    Analyzer analyzer = new MockAnalyzer(random());
    FreeTextSuggester sug = new FreeTextSuggester(analyzer);
    sug.build(new InputArrayIterator(keys));

    try {
      sug.lookup("foo\u001eb", 10);
      fail("did not hit expected exception");
    } catch (IllegalArgumentException iae) {
      // expected
    }
    analyzer.close();
  }

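  // Disabled by default: requires a local copy of the enwiki line file
  // referenced below, so it must be run manually.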
  @Ignore
  public void testWiki() throws Exception {
    final LineFileDocs lfd = new LineFileDocs(null, "/lucenedata/enwiki/enwiki-20120502-lines-1k.txt", false);
    // Skip header:
    lfd.nextDoc();
    Analyzer analyzer = new MockAnalyzer(random());
    FreeTextSuggester sug = new FreeTextSuggester(analyzer);
    sug.build(new InputIterator() {

        private int count;

        @Override
        public long weight() {
          return 1;
        }

        @Override
        public BytesRef next() {
          Document doc;
          try {
            doc = lfd.nextDoc();
          } catch (IOException ioe) {
            throw new RuntimeException(ioe);
          }
          if (doc == null) {
            return null;
          }
          if (count++ == 10000) {
            return null;
          }
          return new BytesRef(doc.get("body"));
        }

        @Override
        public BytesRef payload() {
          return null;
        }

        @Override
        public boolean hasPayloads() {
          return false;
        }

        @Override
        public Set<BytesRef> contexts() {
          return null;
        }

        @Override
        public boolean hasContexts() {
          return false;
        }
      });
    if (VERBOSE) {
      System.out.println(sug.ramBytesUsed() + " bytes");

      List<LookupResult> results = sug.lookup("general r", 10);
      System.out.println("results:");
      for(LookupResult result : results) {
        System.out.println("  " + result);
      }
    }
    analyzer.close();
  }

  // Make sure you can suggest based only on unigram model:
  public void testUnigrams() throws Exception {
    Iterable<Input> keys = AnalyzingSuggesterTest.shuffle(
        new Input("foo bar baz blah boo foo bar foo bee", 50)
    );

    Analyzer a = new MockAnalyzer(random());
    FreeTextSuggester sug = new FreeTextSuggester(a, a, 1, (byte) 0x20);
    sug.build(new InputArrayIterator(keys));
    // Sorts first by count, descending, then by term, ascending:
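    // 9 tokens total: bar occurs twice (2/9 ~= 0.22), baz/bee/blah/boo once each (1/9 ~= 0.11):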
    assertEquals("bar/0.22 baz/0.11 bee/0.11 blah/0.11 boo/0.11",
                 toString(sug.lookup("b", 10)));
    a.close();
  }

  // Make sure the last token is not duplicated
  public void testNoDupsAcrossGrams() throws Exception {
    Iterable<Input> keys = AnalyzingSuggesterTest.shuffle(
        new Input("foo bar bar bar bar", 50)
    );
    Analyzer a = new MockAnalyzer(random());
    FreeTextSuggester sug = new FreeTextSuggester(a, a, 2, (byte) 0x20);
    sug.build(new InputArrayIterator(keys));
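    // The only bigram with context "foo" is "foo bar" (count 1, context count 1 -> 1.00);
    // the unigram fallback "bar" is suppressed because that last token was already suggested: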
    assertEquals("foo bar/1.00",
                 toString(sug.lookup("foo b", 10)));
    a.close();
  }

  // Lookup of just the empty string is not allowed and must hit an exception:
  public void testEmptyString() throws Exception {
    Iterable<Input> keys = AnalyzingSuggesterTest.shuffle(
        new Input("foo bar bar bar bar", 50)
    );
    Analyzer a = new MockAnalyzer(random());
    FreeTextSuggester sug = new FreeTextSuggester(a, a, 2, (byte) 0x20);
    sug.build(new InputArrayIterator(keys));
    try {
      sug.lookup("", 10);
      fail("did not hit exception");
    } catch (IllegalArgumentException iae) {
      // expected
    }
    a.close();
  }

  // With one ending hole, ShingleFilter produces "wizard _" and
  // we should properly predict from that:
  public void testEndingHole() throws Exception {
    // Just deletes "of"
    Analyzer a = new Analyzer() {
        @Override
        public TokenStreamComponents createComponents(String field) {
          Tokenizer tokenizer = new MockTokenizer();
          CharArraySet stopSet = StopFilter.makeStopSet("of");
          return new TokenStreamComponents(tokenizer, new StopFilter(tokenizer, stopSet));
        }
      };

    Iterable<Input> keys = AnalyzingSuggesterTest.shuffle(
        new Input("wizard of oz", 50)
    );
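    // After the stop filter, "wizard of oz" becomes "wizard", a position hole, and
    // "oz"; ShingleFilter fills the hole with its "_" filler token, so the stored
    // trigram is "wizard _ oz" and the query "wizard of" analyzes to "wizard _":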
    FreeTextSuggester sug = new FreeTextSuggester(a, a, 3, (byte) 0x20);
    sug.build(new InputArrayIterator(keys));
    assertEquals("wizard _ oz/1.00",
                 toString(sug.lookup("wizard of", 10)));

    // Falls back to unigram model, with backoff 0.4 times
    // unigram prob 0.5:
    assertEquals("oz/0.20",
                 toString(sug.lookup("wizard o", 10)));
    a.close();
  }

  // If the number of ending holes exceeds the ngrams window
  // then there are no predictions, because ShingleFilter
  // does not produce e.g. a hole only "_ _" token:
  public void testTwoEndingHoles() throws Exception {
    // Just deletes "of"
    Analyzer a = new Analyzer() {
        @Override
        public TokenStreamComponents createComponents(String field) {
          Tokenizer tokenizer = new MockTokenizer();
          CharArraySet stopSet = StopFilter.makeStopSet("of");
          return new TokenStreamComponents(tokenizer, new StopFilter(tokenizer, stopSet));
        }
      };

    Iterable<Input> keys = AnalyzingSuggesterTest.shuffle(
        new Input("wizard of of oz", 50)
    );
    FreeTextSuggester sug = new FreeTextSuggester(a, a, 3, (byte) 0x20);
    sug.build(new InputArrayIterator(keys));
    assertEquals("",
                 toString(sug.lookup("wizard of of", 10)));
    a.close();
  }

  private static Comparator<LookupResult> byScoreThenKey = new Comparator<LookupResult>() {
    @Override
    public int compare(LookupResult a, LookupResult b) {
      if (a.value > b.value) {
        return -1;
      } else if (a.value < b.value) {
        return 1;
      } else {
        // Tie break by UTF16 sort order:
        return ((String) a.key).compareTo((String) b.key);
      }
    }
  };

  public void testRandom() throws IOException {
    String[] terms = new String[TestUtil.nextInt(random(), 2, 10)];
    Set<String> seen = new HashSet<>();
    while (seen.size() < terms.length) {
      String token = TestUtil.randomSimpleString(random(), 1, 5);
      if (!seen.contains(token)) {
        terms[seen.size()] = token;
        seen.add(token);
      }
    }

    Analyzer a = new MockAnalyzer(random());

    int numDocs = atLeast(10);
    long totTokens = 0;
    final String[][] docs = new String[numDocs][];
    for(int i=0;i<numDocs;i++) {
      docs[i] = new String[atLeast(100)];
      if (VERBOSE) {
        System.out.print("  doc " + i + ":");
      }
      for(int j=0;j<docs[i].length;j++) {
        docs[i][j] = getZipfToken(terms);
        if (VERBOSE) {
          System.out.print(" " + docs[i][j]);
        }
      }
      if (VERBOSE) {
        System.out.println();
      }
      totTokens += docs[i].length;
    }

    int grams = TestUtil.nextInt(random(), 1, 4);

    if (VERBOSE) {
      System.out.println("TEST: " + terms.length + " terms; " + numDocs + " docs; " + grams + " grams");
    }

    // Build suggester model:
    FreeTextSuggester sug = new FreeTextSuggester(a, a, grams, (byte) 0x20);
    sug.build(new InputIterator() {
        int upto;

        @Override
        public BytesRef next() {
          if (upto == docs.length) {
            return null;
          } else {
            StringBuilder b = new StringBuilder();
            for(String token : docs[upto]) {
              b.append(' ');
              b.append(token);
            }
            upto++;
            return new BytesRef(b.toString());
          }
        }

        @Override
        public long weight() {
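          // FreeTextSuggester builds its model from ngram counts and does not
          // use per-input weights, so a random value is fine here: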
          return random().nextLong();
        }

        @Override
        public BytesRef payload() {
          return null;
        }

        @Override
        public boolean hasPayloads() {
          return false;
        }

        @Override
        public Set<BytesRef> contexts() {
          return null;
        }

        @Override
        public boolean hasContexts() {
          return false;
        }
      });

    // Build inefficient but hopefully correct model:
    List<Map<String,Integer>> gramCounts = new ArrayList<>(grams);
    for(int gram=0;gram<grams;gram++) {
      if (VERBOSE) {
        System.out.println("TEST: build model for gram=" + gram);
      }
      Map<String,Integer> model = new HashMap<>();
      gramCounts.add(model);
      for(String[] doc : docs) {
        for(int i=0;i<doc.length-gram;i++) {
          StringBuilder b = new StringBuilder();
          for(int j=i;j<=i+gram;j++) {
            if (j > i) {
              b.append(' ');
            }
            b.append(doc[j]);
          }
          String token = b.toString();
          Integer curCount = model.get(token);
          if (curCount == null) {
            model.put(token, 1);
          } else {
            model.put(token, 1 + curCount);
          }
          if (VERBOSE) {
            System.out.println("  add '" + token + "' -> count=" + model.get(token));
          }
        }
      }
    }

    int lookups = atLeast(100);
    for(int iter=0;iter<lookups;iter++) {
      String[] tokens = new String[TestUtil.nextInt(random(), 1, 5)];
      for(int i=0;i<tokens.length;i++) {
        tokens[i] = getZipfToken(terms);
      }

      // Maybe trim last token; be sure not to create the
      // empty string:
      int trimStart;
      if (tokens.length == 1) {
        trimStart = 1;
      } else {
        trimStart = 0;
      }
      int trimAt = TestUtil.nextInt(random(), trimStart, tokens[tokens.length - 1].length());
      tokens[tokens.length-1] = tokens[tokens.length-1].substring(0, trimAt);

      int num = TestUtil.nextInt(random(), 1, 100);
      StringBuilder b = new StringBuilder();
      for(String token : tokens) {
        b.append(' ');
        b.append(token);
      }
      String query = b.toString();
      query = query.substring(1);

      if (VERBOSE) {
        System.out.println("\nTEST: iter=" + iter + " query='" + query + "' num=" + num);
      }

      // Expected:
      List<LookupResult> expected = new ArrayList<>();
      double backoff = 1.0;
      seen = new HashSet<>();
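      // Compute expected results with hand-rolled stupid backoff: starting from
      // the longest gram model, score = backoff * count(context + term) / count(context),
      // where backoff is multiplied by ALPHA each time we fall back to a shorter gram: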

      if (VERBOSE) {
        System.out.println("  compute expected");
      }
      for(int i=grams-1;i>=0;i--) {
        if (VERBOSE) {
          System.out.println("    grams=" + i);
        }

        if (tokens.length < i+1) {
          // Don't have enough tokens to use this model
          if (VERBOSE) {
            System.out.println("      skip");
          }
          continue;
        }

        if (i == 0 && tokens[tokens.length-1].length() == 0) {
          // Never suggest unigrams from empty string:
          if (VERBOSE) {
            System.out.println("      skip unigram priors only");
          }
          continue;
        }

        // Build up "context" ngram:
        b = new StringBuilder();
        for(int j=tokens.length-i-1;j<tokens.length-1;j++) {
          b.append(' ');
          b.append(tokens[j]);
        }
        String context = b.toString();
        if (context.length() > 0) {
          context = context.substring(1);
        }
        if (VERBOSE) {
          System.out.println("      context='" + context + "'");
        }
        long contextCount;
        if (context.length() == 0) {
          contextCount = totTokens;
        } else {
          Integer count = gramCounts.get(i-1).get(context);
          if (count == null) {
            // We never saw this context:
            backoff *= FreeTextSuggester.ALPHA;
            if (VERBOSE) {
              System.out.println("      skip: never saw context");
            }
            continue;
          }
          contextCount = count;
        }
        if (VERBOSE) {
          System.out.println("      contextCount=" + contextCount);
        }
        Map<String,Integer> model = gramCounts.get(i);

        // First pass, gather all predictions for this model:
        if (VERBOSE) {
          System.out.println("      find terms w/ prefix=" + tokens[tokens.length-1]);
        }
        List<LookupResult> tmp = new ArrayList<>();
        for(String term : terms) {
          if (term.startsWith(tokens[tokens.length-1])) {
            if (VERBOSE) {
              System.out.println("        term=" + term);
            }
            if (seen.contains(term)) {
              if (VERBOSE) {
                System.out.println("          skip seen");
              }
              continue;
            }
            String ngram = (context + " " + term).trim();
            Integer count = model.get(ngram);
            if (count != null) {
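              // Scale the probability into the long score space used by
              // LookupResult; toString below divides by Long.MAX_VALUE to recover it: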
              LookupResult lr = new LookupResult(ngram, (long) (Long.MAX_VALUE * (backoff * (double) count / contextCount)));
              tmp.add(lr);
              if (VERBOSE) {
                System.out.println("      add tmp key='" + lr.key + "' score=" + lr.value);
              }
            }
          }
        }

        // Second pass, trim to only top N, and fold those
        // into overall suggestions:
        Collections.sort(tmp, byScoreThenKey);
        if (tmp.size() > num) {
          tmp.subList(num, tmp.size()).clear();
        }
        for(LookupResult result : tmp) {
          String key = result.key.toString();
          int idx = key.lastIndexOf(' ');
          String lastToken;
          if (idx != -1) {
            lastToken = key.substring(idx+1);
          } else {
            lastToken = key;
          }
          if (!seen.contains(lastToken)) {
            seen.add(lastToken);
            expected.add(result);
            if (VERBOSE) {
              System.out.println("      keep key='" + result.key + "' score=" + result.value);
            }
          }
        }

        backoff *= FreeTextSuggester.ALPHA;
      }

      Collections.sort(expected, byScoreThenKey);

      if (expected.size() > num) {
        expected.subList(num, expected.size()).clear();
      }

      // Actual:
      List<LookupResult> actual = sug.lookup(query, num);

      if (VERBOSE) {
        System.out.println("  expected: " + expected);
        System.out.println("    actual: " + actual);
      }

      assertEquals(expected.toString(), actual.toString());
    }
    a.close();
  }

  private static String getZipfToken(String[] tokens) {
    // Zipf-like distribution: token k is returned with probability ~1/2^(k+1);
    // the last token absorbs the remainder:
    for(int k=0;k<tokens.length;k++) {
      if (random().nextBoolean() || k == tokens.length-1) {
        return tokens[k];
      }
    }
    assert false;
    return null;
  }

  private static String toString(List<LookupResult> results) {
    StringBuilder b = new StringBuilder();
    for(LookupResult result : results) {
      b.append(' ');
      b.append(result.key);
      b.append('/');
      b.append(String.format(Locale.ROOT, "%.2f", ((double) result.value)/Long.MAX_VALUE));
    }
    return b.toString().trim();
  }
}